{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed Data Processing using SageMaker Processing and DJL Spark Container\n",
    "\n",
    "This example notebook demonstrates how to use Amazon SageMaker Processing with the DJL Spark Docker image to run distributed deep learning inference on large datasets. The DJL Spark Docker image is a pre-built image that includes the Deep Java Library (DJL) and the other dependencies needed to run distributed data processing jobs on Amazon SageMaker. DJL is an open-source, Java-based deep learning library designed to be easy to use and compatible with existing deep learning ecosystems.\n",
    "\n",
    "By using these two services together, you can easily run distributed deep learning inference on large datasets in a scalable and cost-effective manner."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Setup](#Setup)\n",
    "1. [Push the Container to ECR](#Push-the-Container-to-ECR)\n",
    "1. [Prepare Processing Script](#Prepare-Processing-Script)\n",
    "1. [Run the SageMaker Processing Job](#Run-the-SageMaker-Processing-Job)\n",
    "1. [Monitor and Analyze Your Job](#Monitor-and-Analyze-Your-Job)\n",
    "1. [Validate Data Processing Results](#Validate-Data-Processing-Results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install the SageMaker Python SDK\n",
    "\n",
    "First, install the SageMaker Python SDK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up the account, region, and role\n",
    "\n",
    "Next, define the account, region, and role that will be used to run the SageMaker Processing job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from time import gmtime, strftime\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "account_id = sagemaker_session.account_id()\n",
    "region = sagemaker_session.boto_region_name\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Push the Container to ECR\n",
    "\n",
    "The following section pulls the DJL Spark Docker image from Docker Hub and pushes it to your Amazon ECR repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docker_registry = \"deepjavalibrary\"\n",
    "repository_name = \"djl-spark\"\n",
    "tag = \"0.27.0-cpu\"\n",
    "ecr_registry = \"{}.dkr.ecr.{}.amazonaws.com\".format(account_id, region)\n",
    "\n",
    "# Pull the DJL Spark image\n",
    "!docker pull $docker_registry/$repository_name:$tag\n",
    "\n",
    "# Log in to your ECR and create the repository if it does not already exist.\n",
    "# Note: `aws ecr get-login` was removed in AWS CLI v2; use get-login-password instead.\n",
    "!aws ecr get-login-password --region $region | docker login --username AWS --password-stdin $ecr_registry\n",
    "!aws ecr describe-repositories --repository-names $repository_name > /dev/null 2>&1 || aws ecr create-repository --repository-name $repository_name > /dev/null\n",
    "!docker tag $docker_registry/$repository_name:$tag $ecr_registry/$repository_name:$tag\n",
    "!docker push $ecr_registry/$repository_name:$tag"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Processing Script\n",
    "\n",
    "The source for the processing script is in the cell below. The cell uses the `%%writefile` cell magic to save the file locally.\n",
    "\n",
    "The code reads the input dataset, then applies the `facebook/opt-125m` model to perform text generation. The output is written to the output bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the local directory that will hold the processing script\n",
    "!mkdir -p ./code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile ./code/process.py\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import col, substring_index\n",
    "from djl_spark.task.text import TextGenerator\n",
    "\n",
    "\n",
    "def main():\n",
    "    parser = argparse.ArgumentParser(description=\"app inputs and outputs\")\n",
    "    parser.add_argument(\"--s3_input_bucket\", type=str, help=\"s3 input bucket\")\n",
    "    parser.add_argument(\"--s3_input_key_prefix\", type=str, help=\"s3 input key prefix\")\n",
    "    parser.add_argument(\"--s3_output_bucket\", type=str, help=\"s3 output bucket\")\n",
    "    parser.add_argument(\"--s3_output_key_prefix\", type=str, help=\"s3 output key prefix\")\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    spark = SparkSession.builder.appName(\"sm-spark-djl-text-generation\").getOrCreate()\n",
    "\n",
    "    # Input\n",
    "    df = spark.read.json(\"s3://\" + os.path.join(args.s3_input_bucket, args.s3_input_key_prefix))\n",
    "\n",
    "    # Truncate the input to 8 words to make the input shorter than max_length 30\n",
    "    df = df.select(substring_index(col(\"inputs\"), \" \", 8).alias(\"inputs\"))\n",
    "\n",
    "    # Text generation\n",
    "    generator = TextGenerator(input_col=\"inputs\",\n",
    "                              output_col=\"prediction\",\n",
    "                              hf_model_id=\"facebook/opt-125m\")\n",
    "    output_df = generator.generate(df, do_sample=True, max_length=30)\n",
    "    output_df.write.mode(\"overwrite\").parquet(\"s3://\" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix))\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
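  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a side note, the truncation step in the script keeps only the text before the 8th space, so each prompt stays well under the `max_length` of 30. The plain-Python sketch below approximates what Spark's `substring_index(col, \" \", 8)` does for simple space-separated text; it is an illustration only, not part of the processing script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_words(text, n=8):\n",
    "    # Keep the text before the n-th space, mirroring substring_index(col, \" \", n).\n",
    "    # Inputs with fewer than n words are returned unchanged.\n",
    "    return \" \".join(text.split(\" \")[:n])\n",
    "\n",
    "print(truncate_words(\"This product exceeded my expectations in every possible way imaginable\"))"
   ]
  },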
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the SageMaker Processing Job"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll create a `PySparkProcessor` with the following parameters:\n",
    "\n",
    "* `base_job_name`: Set the prefix for processing name to \"sm-spark-djl\".\n",
    "* `image_uri`: Set the URI of the Docker image to the image you pushed to Amazon ECR above.\n",
    "* `role`: Set the AWS IAM role to use for the processing job.\n",
    "* `instance_count`: Set the number of instances to run the processing job to 2.\n",
    "* `instance_type`: Set the type of EC2 instance to use for processing to \"ml.m5.2xlarge\".\n",
    "\n",
    "We also set a Spark configuration:\n",
    "\n",
    "* `spark.executor.memory`: Set the amount of memory to use per executor process to 2g.\n",
    "* `spark.executor.cores`: Set the number of cores to use on each executor to 2.\n",
    "* `spark.sql.files.maxPartitionBytes`: Set the maximum number of bytes to pack into a single partition to 2097152 (2 MB). The default is 134217728 (128 MB). We lower this value to increase parallelism.\n",
    "\n",
    "Then, the code calls the `run` method of the processor to start the processing job. It passes in the following parameters:\n",
    "\n",
    "* `submit_app`: The path of the processing script to submit to Spark.\n",
    "* `arguments`: A list of string arguments passed to the processing job, including the input and output locations. The input dataset is 1000 records from the [Amazon Reviews](https://cseweb.ucsd.edu/~jmcauley/datasets.html#amazon_reviews) dataset.\n",
    "* `configuration`: The Spark configuration defined above.\n",
    "* `spark_event_logs_s3_uri`: The S3 URI where Spark application events will be published.\n",
    "* `logs`: Whether to show the logs produced by the job; set to `False` here.\n",
    "* `wait`: Whether to block until the job completes; set to `True` here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.spark.processing import PySparkProcessor\n",
    "\n",
    "input_bucket = \"alpha-djl-demos\"\n",
    "input_prefix = \"dataset/nlp/amazon_reviews\"\n",
    "\n",
    "timestamp_prefix = strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "prefix = \"sagemaker/spark-processing-djl-demo/{}\".format(timestamp_prefix)\n",
    "output_bucket = sagemaker_session.default_bucket()\n",
    "output_prefix = f\"{prefix}/output\"\n",
    "\n",
    "image_uri = \"{}/{}:{}\".format(ecr_registry, repository_name, tag)\n",
    "\n",
    "# Run the processing job\n",
    "spark_processor = PySparkProcessor(\n",
    "    base_job_name=\"sm-spark-djl\",\n",
    "    image_uri=image_uri,\n",
    "    role=role,\n",
    "    instance_count=2,\n",
    "    instance_type=\"ml.m5.2xlarge\"\n",
    ")\n",
    "\n",
    "configuration = [\n",
    "    {\n",
    "        \"Classification\": \"spark-defaults\",\n",
    "        \"Properties\": {\"spark.executor.memory\": \"2g\", \"spark.executor.cores\": \"2\", \"spark.sql.files.maxPartitionBytes\": \"2097152\"}\n",
    "    }\n",
    "]\n",
    "\n",
    "spark_processor.run(\n",
    "    submit_app=\"./code/process.py\",\n",
    "    arguments=[\n",
    "        \"--s3_input_bucket\", input_bucket,\n",
    "        \"--s3_input_key_prefix\", input_prefix,\n",
    "        \"--s3_output_bucket\", output_bucket,\n",
    "        \"--s3_output_key_prefix\", output_prefix,\n",
    "    ],\n",
    "    configuration=configuration,\n",
    "    spark_event_logs_s3_uri=\"s3://{}/{}/spark_event_logs\".format(output_bucket, prefix),\n",
    "    logs=False, # Do not show the logs produced by the job\n",
    "    wait=True # Wait until the job completes\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monitor and Analyze Your Job\n",
    "\n",
    "Next, call `start_history_server()` to start the Spark history server and access the Spark UI, where you can view details about the Spark application. This is useful for debugging and troubleshooting, as well as for monitoring the performance and behavior of your Spark processing job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark_processor.start_history_server()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After viewing the Spark UI, you can terminate the history server before proceeding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark_processor.terminate_history_server()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate Data Processing Results\n",
    "\n",
    "Next, validate the output of the Spark job by ensuring that the output URI contains the Spark `_SUCCESS` file along with the output files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Output files in s3://{}/{}/\".format(output_bucket, output_prefix))\n",
    "!aws s3 ls s3://$output_bucket/$output_prefix/"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
